Utilities for the models of this thesis

comment

Visualization

Computation

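For \(p \in \mathbf R^{k}_{>0}\) with \(\sum_{i = 1}^k p_{i} = 1\), let \(\log q_i = \log \frac{p_{i}}{p_{k}}\) for \(i = 1, \dots, k - 1\). Since \(1 = \sum_{i = 1}^k p_{i} = p_{k} \left(1 + \sum_{i = 1}^{k-1} q_{i}\right)\), we get \[ p_{k} = \frac{1}{1 + \sum_{i = 1}^{k-1}q_{i}}, \] so \[ p_{i} = q_{i} p_{k} = \frac{q_{i}}{1 + \sum_{j = 1}^{k-1} q_{j}}, \quad i = 1, \dots, k - 1. \]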
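Setting \(\log q_k = 0\) (so \(q_k = 1\)), the formulas above say that \(p\) is the softmax of \((\log q_1, \dots, \log q_{k-1}, 0)\). A minimal sketch of both directions, assuming JAX; `from_log_ratios` and `to_log_ratios` are hypothetical names, not functions from this thesis.

```python
import jax.numpy as jnp
from jax.nn import softmax


def from_log_ratios(log_q):
    """Hypothetical helper: k-1 log-ratios log(p_i / p_k) -> probability vector of length k.

    Appending log q_k = 0 and applying softmax gives
    p_i = q_i / (1 + sum_j q_j) for i < k and p_k = 1 / (1 + sum_j q_j).
    """
    return softmax(jnp.append(log_q, 0.0))


def to_log_ratios(p):
    """Hypothetical inverse: log(p_i / p_k) for i = 1, ..., k-1."""
    return jnp.log(p[:-1]) - jnp.log(p[-1])


log_q = jnp.array([0.5, -1.0, 2.0])
p = from_log_ratios(log_q)
# closed form for the last entry, p_k = 1 / (1 + sum_i q_i)
p_k_closed = 1.0 / (1.0 + jnp.exp(log_q).sum())
```

Using `softmax` rather than the explicit quotient avoids overflow when some \(\log q_i\) is large, since JAX's softmax subtracts the maximum before exponentiating.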

Another parametrization uses the consecutive conditional probabilities, mapped through the logit so that the parameters are unconstrained.

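Concretely, for \(p \in \mathbf R^{k}_{>0}\) with \(\sum_{i = 1}^k p_{i} = 1\) we set \[ q_{i} = \frac{p_{i}}{1 - \sum_{j = 1}^{i - 1} p_{j}} = \frac{p_{i}}{\sum_{j = i}^k p_{j}} \] for \(i = 1, \dots, k - 1\), i.e. \(q_i\) is the probability of category \(i\) conditional on the category being at least \(i\). Since \(q_k = 1\), it can be discarded, and the unconstrained parameters are \(\operatorname{logit}(q_{1}), \dots, \operatorname{logit}(q_{k-1})\).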

Conversely, for \(i = 1, \dots, k\) (with \(q_k = 1\) and the empty product equal to \(1\)) \[ p_{i} = q_{i} \prod_{j = 1}^{i - 1}(1 - q_j). \]
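Both directions of this "stick-breaking" map can be vectorized with cumulative products and sums. A minimal sketch, assuming JAX; `from_consecutive_logits_sketch` and `to_consecutive_logits_sketch` are hypothetical stand-ins and not the thesis's own `from_consecutive_logits`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import expit, logit


def from_consecutive_logits_sketch(logits):
    """Hypothetical: logits of q_1, ..., q_{k-1} -> probabilities p_1, ..., p_k."""
    q = expit(logits)
    # (0-based) remaining[i] = prod_{j < i} (1 - q[j]); remaining[0] = 1 is the empty product
    remaining = jnp.concatenate([jnp.ones(1), jnp.cumprod(1 - q)])
    # p_i = q_i * remaining[i], with q_k = 1 for the last category
    return jnp.append(q, 1.0) * remaining


def to_consecutive_logits_sketch(p):
    """Hypothetical inverse: q_i = p_i / sum_{j >= i} p_j, returned as logits."""
    tail = jnp.cumsum(p[::-1])[::-1]  # tail[i] = sum_{j >= i} p[j]
    return logit(p[:-1] / tail[:-1])


# logits of 0 give q = (1/2, 1/2), hence p = (1/2, 1/4, 1/4)
p_half = from_consecutive_logits_sketch(jnp.zeros(2))

logits = jax.random.normal(jax.random.PRNGKey(1), (5,))
p = from_consecutive_logits_sketch(logits)
recovered = to_consecutive_logits_sketch(p)
```

The inverse uses the second form of \(q_i\), dividing by the tail sums \(\sum_{j \ge i} p_j\) instead of \(1 - \sum_{j < i} p_j\), which avoids cancellation when the leading probabilities are close to \(1\).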

Checking the derivative

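Write \(\ell_{m} = \operatorname{logit}(q_{m})\) for \(m = 1, \dots, k - 1\), so that \(q_{m} = \operatorname{expit}(\ell_{m})\) and \(\partial_{\ell_{m}} q_{m} = q_{m}(1 - q_{m})\). Then for \(i = 1, \dots, k\) we have \[ \partial_{\ell_{m}}(p_{i}) = \partial_{\ell_{m}} \left( \operatorname{expit}(\ell_{i}) \prod_{j = 1}^{i - 1} \left(1 - \operatorname{expit}(\ell_{j})\right) \right) = \begin{cases} p_{i} (1 - q_{i}) & m = i, \\ -p_{i}q_{m} & m < i, \\ 0 & m > i, \end{cases} \] where for \(i = k\) the factor \(\operatorname{expit}(\ell_{k})\) is replaced by \(q_{k} = 1\), so only the case \(m < i\) occurs.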
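The case distinction can be assembled into the full \(k \times (k-1)\) Jacobian without loops and compared against JAX's forward-mode Jacobian. A minimal sketch, assuming JAX; `stick_breaking_sketch` and `consecutive_logits_jacobian` are hypothetical names. Because \(\sum_i p_i = 1\), every column of the Jacobian must also sum to zero, which gives a second check.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import expit


def stick_breaking_sketch(logits):
    """Hypothetical stand-in for the logits -> p map (k-1 inputs, k outputs)."""
    q = expit(logits)
    remaining = jnp.concatenate([jnp.ones(1), jnp.cumprod(1 - q)])
    return jnp.append(q, 1.0) * remaining


def consecutive_logits_jacobian(logits):
    """Closed-form Jacobian J[i, m] = d p_i / d logit(q_m), shape (k, k-1)."""
    q = expit(logits)
    p = stick_breaking_sketch(logits)
    i = jnp.arange(p.shape[0])[:, None]  # row index: category
    m = jnp.arange(q.shape[0])[None, :]  # column index: logit
    return jnp.where(
        m < i,
        -p[:, None] * q[None, :],  # earlier breaks: -p_i q_m
        # diagonal p_i (1 - q_i); zero above the diagonal
        jnp.where(m == i, p[:, None] * (1 - q[None, :]), 0.0),
    )


logits = jax.random.normal(jax.random.PRNGKey(0), (5,))
J_closed = consecutive_logits_jacobian(logits)
J_auto = jax.jacfwd(stick_breaking_sketch)(logits)
```

The last row needs no special treatment here: for \(i = k\) every column index satisfies \(m < i\), which yields \(-p_k q_m\).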

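import fastcore.test as fct
import jax.numpy as jnp
import jax.random as jrn
import jax.scipy as jsp
from jax import jvp
from jaxtyping import Array, Float

# `from_consecutive_logits` (logits of q_1, ..., q_{k-1} -> p_1, ..., p_k) is assumed to be in scope.


def grad_from_consecutive_logits(
    primals: Float[Array, "k-1"], tangents: Float[Array, "k-1"]
) -> tuple[Float[Array, "k"], Float[Array, "k"]]:
    """Return p and the Jacobian-vector product J @ tangents, using the closed-form Jacobian."""
    (l,) = primals.shape  # l = k - 1
    jac = jnp.zeros((l + 1, l))
    p = from_consecutive_logits(primals)
    q = jsp.special.expit(primals)

    for i in range(l):
        # diagonal: d p_i / d logit(q_i) = p_i (1 - q_i)
        jac = jac.at[i, i].set(p[i] * (1 - q[i]))
        # earlier breaks: d p_i / d logit(q_m) = -p_i q_m for m < i
        for m in range(i):
            jac = jac.at[i, m].set(-p[i] * q[m])
    # last category: p_k = prod_m (1 - q_m), so d p_k / d logit(q_m) = -p_k q_m
    jac = jac.at[l, :].set(-jnp.prod(1 - q) * q)
    return p, jac @ tangents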


# random primal (5 logits, so k = 6) and tangent
key = jrn.PRNGKey(0)
key, subkey = jrn.split(key)
rand_primal = jrn.normal(subkey, (5,))
key, subkey = jrn.split(key)
rand_tangent = jrn.normal(subkey, (5,))

# JAX's forward-mode JVP minus the hand-derived one; every entry should be close to 0
(
    jvp(from_consecutive_logits, (rand_primal,), (rand_tangent,))[1]
    - grad_from_consecutive_logits(rand_primal, rand_tangent)[1]
)


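# elementwise relative error; the 1e-10 guards against division by zero
def rel_error(a, b):
    return jnp.abs(a - b) / (jnp.abs(a) + jnp.abs(b) + 1e-10)


# the relative error of each of the k = 6 JVP entries should be close to 0
fct.test_close(
    rel_error(
        jvp(from_consecutive_logits, (rand_primal,), (rand_tangent,))[1],
        grad_from_consecutive_logits(rand_primal, rand_tangent)[1],
    ),
    jnp.zeros(rand_primal.shape[0] + 1),
)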

Exporting results